from __future__ import absolute_import

import logging
import torch
from torch import nn

from transformers.modeling_bert import BertPreTrainedModel, BertModel
from transformers.modeling_roberta import RobertaModel, RobertaConfig, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

import math

from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm



# Configure root logging when this module is imported (timestamped, INFO level)
# and create the module-level logger used by the model below.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)
logger = logging.getLogger(__name__)

def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))``, i.e. ``x`` times the standard
    normal CDF evaluated at ``x``. ``x`` must be a tensor (``torch.erf`` is used).
    """
    half_x = x * 0.5
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))


class GeLU(nn.Module):
    """Parameter-free ``nn.Module`` wrapper around the module-level ``gelu``.

    Lets the activation be passed wherever a layer object is expected.
    ``nn.Module.__init__`` is inherited unchanged.
    """

    def forward(self, x):
        return gelu(x)
    
class Highway(nn.Module):
    """Stack of gated highway-style layers followed by a linear output projection.

    Each layer computes ``g * f(W_n x) + (1 - g) * W_l x`` with gate
    ``g = sigmoid(W_g x)``. The carry path goes through its own ``nn.Linear``
    rather than passing ``x`` through unchanged. After the stack, ``f`` is
    applied once more and the result is projected from ``size`` to ``out_size``
    features.

    Args:
        size: feature dimension (last dim) of the input and of every layer.
        out_size: feature dimension of the output.
        num_layers: number of gated layers; 0 gives just ``linear_proj(f(x))``.
        f: non-linearity applied on the transform path and before the
            projection. If it is an ``nn.Module``, it is registered as a
            submodule.
    """

    def __init__(self, size, out_size, num_layers, f):
        super(Highway, self).__init__()
        self.num_layers = num_layers
        # Attribute names form the state_dict keys and construction order fixes
        # RNG consumption; keep both stable so existing checkpoints still load.
        self.nonlinear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.linear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.gate = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.f = f
        self.linear_proj = nn.Linear(size, out_size)

    def forward(self, x):
        """Map ``x`` of shape (..., size) to (..., out_size)."""
        for gate_layer, nonlinear_layer, linear_layer in zip(self.gate, self.nonlinear, self.linear):
            # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
            gate = torch.sigmoid(gate_layer(x))
            x = gate * self.f(nonlinear_layer(x)) + (1 - gate) * linear_layer(x)
        return self.linear_proj(self.f(x))
    
class NCEModel(BertPreTrainedModel):
    """RoBERTa-based scorer for (context, question, answer) triples.

    The three texts are encoded separately by one shared ``RobertaModel``.
    Each is represented by the encoder's second output (the ``cls_*`` vectors
    below). For every member X of the triple:

    * a 4-layer :class:`Highway` predicts X from the other two vectors
      (``cq_to_a``, ``ca_to_q``, ``qa_to_c``);
    * a 2-layer Highway re-encodes X's own vector (``a_to_ah``, ``q_to_qh``,
      ``c_to_ch``).

    :meth:`forward` turns the prediction / re-encoding agreement into a
    training loss. :meth:`score` turns it into per-candidate distances.

    Args:
        config: a ``RobertaConfig``.
        args: optional hyper-parameter namespace. ``args.losstype`` selects the
            default training loss ("l2" or "nce"). When ``args`` is ``None`` or
            lacks ``losstype``, "l2" is used.
    """
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config, args=None):
        super(NCEModel, self).__init__(config)

        self.roberta = RobertaModel(config)
        # NOTE(review): not used by any method in this class; kept so the
        # attribute remains available to external code.
        self._dropout = nn.Dropout(config.hidden_dropout_prob)

        hid_dim = config.hidden_size

        # ReLU non-linearity, shared by every Highway below (the module-level
        # GeLU is an alternative).
        nonlinear_f = nn.ReLU()

        # Predict X from the concatenation of the other two (2*hid_dim inputs),
        # and re-encode X itself (hid_dim inputs). Construction order is kept
        # fixed so parameter initialisation consumes the RNG identically.
        self.cq_to_a = Highway(2 * hid_dim, hid_dim, 4, nonlinear_f)
        self.a_to_ah = Highway(hid_dim, hid_dim, 2, nonlinear_f)

        self.ca_to_q = Highway(2 * hid_dim, hid_dim, 4, nonlinear_f)
        self.q_to_qh = Highway(hid_dim, hid_dim, 2, nonlinear_f)

        self.qa_to_c = Highway(2 * hid_dim, hid_dim, 4, nonlinear_f)
        self.c_to_ch = Highway(hid_dim, hid_dim, 2, nonlinear_f)

        self.debug = False
        self.hparams = args

        # from_pretrained(...) without args=... passes args=None; default to "l2"
        # instead of failing on None.losstype.
        self.losstype = getattr(args, "losstype", "l2")
        # score() uses L2 for an "l2"-trained model and cosine for anything else.
        self.dist_type = "l2" if self.losstype == "l2" else "cos"

        self.init_weights()

        self.cos = nn.CosineSimilarity()

    def _encode(self, ids, mask):
        """Return the encoder's second output (the ``cls_*`` vector) per sequence.

        Leading dims are flattened into the batch dim, so inputs of shape
        (batch, options, max_len) yield (batch * options, hidden) outputs.
        """
        outputs = self.roberta(input_ids=ids.reshape(-1, ids.size(-1)),
                               attention_mask=mask.reshape(-1, mask.size(-1)))
        return outputs[1]

    def _cross_predict(self, cls_ctx, cls_qs, cls_ans):
        """Predict each member from the other two; return (cq2a, ca2q, qa2c)."""
        cq2a = self.cq_to_a(torch.cat([cls_ctx, cls_qs], 1))
        ca2q = self.ca_to_q(torch.cat([cls_ctx, cls_ans], 1))
        qa2c = self.qa_to_c(torch.cat([cls_qs, cls_ans], 1))
        return cq2a, ca2q, qa2c

    def _self_encode(self, cls_ctx, cls_qs, cls_ans):
        """Re-encode each member on its own; return (ahat, qhat, chat)."""
        return self.a_to_ah(cls_ans), self.q_to_qh(cls_qs), self.c_to_ch(cls_ctx)

    def nce_loss(self, pos_a, all_a):
        """Mean of ``-log(exp(pos) / sum_j exp(all_j))``.

        Args:
            pos_a: similarity of each true triple, shape (batch,) or (batch, 1).
            all_a: similarities forming the denominator, summed over the last
                dim. Pass (batch, options) for a per-example denominator.

        Returns:
            Scalar tensor.
        """
        # logsumexp(all) - pos equals -log(exp(pos) / sum(exp(all))), computed
        # without explicitly exponentiating.
        loss = torch.logsumexp(all_a, dim=-1) - pos_a.squeeze(-1)
        return loss.mean()

    def forward(self,  # type: ignore
                ctx_ids,  # shape: batch_size, max_len
                ctx_mask,  # shape: batch_size, max_len
                qs_ids,  # shape: batch_size, max_len
                qs_mask,  # shape: batch_size, max_len
                ans_ids,  # shape: batch_size, max_len
                ans_mask,  # shape: batch_size, max_len
                ctx_all_ids=None,  # shape: batch_size, options, max_len
                ctx_all_mask=None,
                qs_all_idx=None,
                qs_all_mask=None,
                ans_all_ids=None,
                ans_all_mask=None,
                loss_type=None
                ):
        """Compute the training loss for a batch of true triples.

        Args:
            ctx_ids, ctx_mask, qs_ids, qs_mask, ans_ids, ans_mask: token ids and
                attention masks of the true context / question / answer.
            ctx_all_ids, ctx_all_mask, qs_all_idx, qs_all_mask, ans_all_ids,
                ans_all_mask: candidate triples, shape
                (batch_size, options, max_len). Required for "nce" only.
            loss_type: "l2" or "nce". ``None`` uses ``self.losstype``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: for an unknown ``loss_type``, or when "nce" is requested
                without all six candidate tensors.
        """
        loss_type = self.losstype if loss_type is None else loss_type
        if loss_type not in ("l2", "nce"):
            raise ValueError(f"Unknown loss_type {loss_type!r}; expected 'l2' or 'nce'.")

        if self.debug:
            logger.info(f"ctx_ids.size={ctx_ids.size()}")
            logger.info(f"ctx_mask.size={ctx_mask.size()}")

        cls_ctx = self._encode(ctx_ids, ctx_mask)
        cls_qs = self._encode(qs_ids, qs_mask)
        cls_ans = self._encode(ans_ids, ans_mask)

        cq2a, ca2q, qa2c = self._cross_predict(cls_ctx, cls_qs, cls_ans)
        ahat, qhat, chat = self._self_encode(cls_ctx, cls_qs, cls_ans)

        if loss_type == "l2":
            # NOTE(review): despite the name, this is the norm of the element-wise
            # *product* of prediction and re-encoding, not of their difference.
            # It is kept as-is to match score(); confirm this is intended.
            cq2a_loss = torch.norm(cq2a * ahat, 2, -1)
            ca2q_loss = torch.norm(ca2q * qhat, 2, -1)
            qa2c_loss = torch.norm(qa2c * chat, 2, -1)
            return (cq2a_loss * ca2q_loss * qa2c_loss).mean()

        candidates = (ctx_all_ids, ctx_all_mask, qs_all_idx, qs_all_mask, ans_all_ids, ans_all_mask)
        if any(t is None for t in candidates):
            raise ValueError("loss_type 'nce' requires ctx_all_*, qs_all_* and ans_all_* inputs.")
        options_num = ctx_all_ids.shape[1]

        cls_actx = self._encode(ctx_all_ids, ctx_all_mask)
        cls_aqs = self._encode(qs_all_idx, qs_all_mask)
        cls_aans = self._encode(ans_all_ids, ans_all_mask)
        alcq2a, alca2q, alqa2c = self._cross_predict(cls_actx, cls_aqs, cls_aans)

        pos_a = self.cos(cq2a, ahat)
        pos_c = self.cos(qa2c, chat)
        pos_q = self.cos(ca2q, qhat)

        # Flattened candidate row b*options + k belongs to example b. Repeat each
        # true re-encoding `options` times to line up, then reshape to
        # (batch, options) so nce_loss sums each example's own candidates.
        all_a = self.cos(alcq2a, ahat.repeat_interleave(options_num, dim=0)).view(-1, options_num)
        all_c = self.cos(alqa2c, chat.repeat_interleave(options_num, dim=0)).view(-1, options_num)
        # Question head: candidate (context, answer) predictions vs. true question.
        all_q = self.cos(alca2q, qhat.repeat_interleave(options_num, dim=0)).view(-1, options_num)

        loss_a = self.nce_loss(pos_a, all_a)
        loss_c = self.nce_loss(pos_c, all_c)
        loss_q = self.nce_loss(pos_q, all_q)
        return loss_a + loss_c + loss_q

    def score(self,  # type: ignore
              ctx_ids,  # shape: batch_size, max_len
              ctx_mask,  # shape: batch_size, max_len
              qs_ids,  # shape: batch_size, max_len
              qs_mask,  # shape: batch_size, max_len
              ans_ids,  # shape: batch_size, options, max_len
              ans_mask,  # shape: batch_size, options, max_len
              dist_type=None):
        """Return a (batch_size, options) distance for every candidate answer.

        Each (context, question) pair is combined with each of its ``options``
        candidate answers. The three prediction / re-encoding pairs are compared
        and their distances multiplied:

        * for "cos", each factor is ``1 - cosine similarity``;
        * for "l2", each factor is the norm of the element-wise product (see the
          NOTE in forward).

        Runs under ``torch.no_grad()``.

        Args:
            dist_type: "l2" or "cos". ``None`` (default) uses ``self.dist_type``.

        Raises:
            ValueError: for an unknown ``dist_type``.
        """
        dist_type = self.dist_type if dist_type is None else dist_type
        if dist_type not in ("l2", "cos"):
            raise ValueError(f"Unknown dist_type {dist_type!r}; expected 'l2' or 'cos'.")
        options_num = ans_ids.shape[1]

        with torch.no_grad():
            # Repeat each context/question `options` times so row b*options + k
            # lines up with candidate answer k of example b.
            cls_ctx = self._encode(ctx_ids, ctx_mask).repeat_interleave(options_num, dim=0)
            cls_qs = self._encode(qs_ids, qs_mask).repeat_interleave(options_num, dim=0)
            cls_ans = self._encode(ans_ids, ans_mask)

            cq2a, ca2q, qa2c = self._cross_predict(cls_ctx, cls_qs, cls_ans)
            ahat, qhat, chat = self._self_encode(cls_ctx, cls_qs, cls_ans)

            if dist_type == "l2":
                dist_scores = (torch.norm(cq2a * ahat, 2, -1)
                               * torch.norm(ca2q * qhat, 2, -1)
                               * torch.norm(qa2c * chat, 2, -1))
            else:
                dist_scores = ((1 - self.cos(cq2a, ahat))
                               * (1 - self.cos(qa2c, chat))
                               * (1 - self.cos(ca2q, qhat)))

        return dist_scores.view(-1, options_num)
    
    
    
if __name__ == '__main__':
    # Smoke test on random token ids: runs both losses and both scoring modes.
    from types import SimpleNamespace

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The constructor reads args.losstype, so pass a namespace explicitly.
    model = NCEModel.from_pretrained(
        "roberta-base", args=SimpleNamespace(losstype="nce")).to(device)

    single = (3, 10)    # (batch_size, max_len)
    multi = (3, 4, 10)  # (batch_size, options, max_len)

    ctx_ids = torch.randint(100, single, device=device)
    qs_ids = torch.randint(100, single, device=device)
    ans_ids = torch.randint(100, single, device=device)
    # All-ones attention masks: every position is a real token.
    # (torch.randint(1, ...) would yield all zeros, i.e. fully masked inputs.)
    ctx_mask = torch.ones(single, dtype=torch.long, device=device)
    qs_mask = torch.ones(single, dtype=torch.long, device=device)
    ans_mask = torch.ones(single, dtype=torch.long, device=device)

    ctx_all_ids = torch.randint(100, multi, device=device)
    qs_all_idx = torch.randint(100, multi, device=device)
    ans_all_ids = torch.randint(100, multi, device=device)
    ctx_all_mask = torch.ones(multi, dtype=torch.long, device=device)
    qs_all_mask = torch.ones(multi, dtype=torch.long, device=device)
    ans_all_mask = torch.ones(multi, dtype=torch.long, device=device)

    # First test: L2Loss, only c,q,a
    loss = model.forward(ctx_ids, ctx_mask, qs_ids, qs_mask, ans_ids, ans_mask, loss_type="l2")
    print(f"L2Loss: Only CQA:{loss}", flush=True)
    # Second test: L2Score, only c,q,a
    score = model.score(ctx_ids, ctx_mask, qs_ids, qs_mask, ans_all_ids, ans_all_mask, dist_type="l2")
    print(f"L2Score: Only CQA:{score}", flush=True)

    # Third test: NCE loss with candidate triples
    loss = model.forward(ctx_ids, ctx_mask, qs_ids, qs_mask, ans_ids, ans_mask,
                         ctx_all_ids, ctx_all_mask, qs_all_idx, qs_all_mask,
                         ans_all_ids, ans_all_mask, loss_type="nce")
    print(f"NCE: Only CQA All:{loss}", flush=True)

    # Fourth test: cosine score over candidate answers
    score = model.score(ctx_ids, ctx_mask, qs_ids, qs_mask, ans_all_ids, ans_all_mask, dist_type="cos")
    print(f"NCE SCore: Only CQA All:{score}", flush=True)
    

        